Skip to content

Conversation

@hfrick
Copy link
Member

@hfrick hfrick commented Jan 26, 2023

Closes #859

The corresponding test in extratest is update here:

This PR is to be merged before the one on extratests.

library(parsnip)
data("hpc_data", package = "modeldata", envir = rlang::current_env())

mr_spec <- multinom_reg(penalty = 0.1) %>% set_engine("glmnet")
f_fit <- fit(mr_spec, class ~ protocol + log(compounds) + input_fields,
             data = hpc_data)
f_pred_class <- multi_predict(f_fit, hpc_data, penalty = c(0.01, 0.1), 
                              type = "class")
f_pred_class$.pred[[1]]
#> # A tibble: 2 × 2
#>   penalty .pred_class
#>     <dbl> <fct>      
#> 1    0.01 VF         
#> 2    0.1  VF

f_pred_prob <- multi_predict(f_fit, hpc_data, penalty = c(0.01, 0.1), 
                             type = "prob")
f_pred_prob$.pred[[1]]
#> # A tibble: 2 × 5
#>   penalty .pred_VF .pred_F .pred_M .pred_L
#>     <dbl>    <dbl>   <dbl>   <dbl>   <dbl>
#> 1    0.01    0.464   0.287   0.174  0.0758
#> 2    0.1     0.488   0.325   0.124  0.0625

Created on 2023-01-26 with reprex v2.0.2

@hfrick hfrick changed the title put penalty column to the left Switch column order of multi_predict() result for multinom_reg(engine = "glmnet") Jan 26, 2023
@hfrick hfrick requested a review from EmilHvitfeldt January 26, 2023 14:08
@hfrick hfrick merged commit 0c8893d into main Feb 7, 2023
@hfrick hfrick deleted the glmnet-multi_predict-column-order branch February 7, 2023 11:33
@github-actions
Copy link
Contributor

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Feb 22, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Switch column order of multi_predict() result for multinom_reg(engine = "glmnet")

3 participants